import os, time, csv, itertools, psutil
import subprocess, resource

from .common import *
from .collect import *


class NodeCollect(BaseCollect):
    """Top-level node metrics collector.

    Instantiates one concrete collector per enabled metric group (cpu, net,
    memory, io, ...), writes a CSV header once, then appends one data row per
    interval to ``<output>/node.data``.
    """

    def __init__(self, output, interval, collects, cgroupPath="", pid=""):
        super().__init__(interval)
        self.pid = pid
        # Metric groups this node/platform can actually provide.
        self.supports = supHelper.getNodeSupports()
        self.cgroupPath = cgroupPath
        self.collects = self.init_collects(collects)
        self.outputFile = "/".join([output, "node.data"])  # TODO: constant
        self.prepare_output_file(collects)

    def _enabled(self, name):
        """Return True when metric group *name* should be collected.

        A group is enabled when the platform supports it and its
        prerequisite is present (a pid for "rdt", a cgroup path for
        "pressure").
        """
        if name not in self.supports:
            return False
        if name == "rdt" and self.pid == "":
            return False
        if name == "pressure" and self.cgroupPath == "":
            return False
        return True

    def init_collects(self, collects):
        """Create the concrete collector instance for each enabled group."""
        collections = []
        classes = supHelper.getNodeClasses()
        for name in collects:
            if not self._enabled(name):
                continue
            cls = globals()[classes[name]]
            if name == "pressure":
                collections.append(
                    cls(v=None, ctrID=None, cgroupPath=self.cgroupPath)
                )
            else:
                collections.append(cls(pid=self.pid))
        return collections

    def collect(self):
        """Collect from every sub-collector forever, one CSV row per interval.

        TODO: net and io collects each block ~1s internally; do them first, or
        run them in separate threads, so row timestamps stay accurate.
        """
        while True:
            row = []
            for c in self.collects:
                row.extend(c.collect())
            self.write(row, self.outputFile)
            time.sleep(self.interval)

    def prepare_output_file(self, collects):
        """Create/truncate the output CSV and write its header row.

        The header column order must match the data order produced by
        collect(), so the same enable-filter is applied here.
        """
        columns = ["timestamp"]
        # Consistency fix: init_collects calls getNodeClasses() on the class;
        # the previous supHelper() instantiation here was unnecessary.
        classes = supHelper.getNodeClasses()
        for name in collects:
            if not self._enabled(name):
                continue
            cls = globals()[classes[name]]
            # io/memory embed the pid in their column names.
            if name in ("io", "memory"):
                columns.extend(cls.get_columns(self.pid))
            else:
                columns.extend(cls.get_columns())

        # open(..., "w") creates the file when missing and truncates it
        # otherwise (portable, works on macOS), so the previous separate
        # existence check + touch was dead code and has been removed.
        with open(self.outputFile, "w") as csvfile:
            csv.writer(csvfile, delimiter=",").writerow(columns)


class NodeCpuCollect(object):
    """Node-wide CPU metrics: per-CPU steal %, per-CPU usage %, load averages."""

    def __init__(self, pid=""):
        self.pid = pid

    @classmethod
    def get_columns(cls):
        """Header names, in the same order collect() emits values."""
        ncpu = psutil.cpu_count()
        steal_cols = ["cpu-%d-steal %%" % i for i in range(ncpu)]
        usage_cols = ["cpu-%d-usage %%" % i for i in range(ncpu)]
        load_cols = ["cpu-loadavg%s %%" % w for w in ("1", "5", "15")]
        return steal_cols + usage_cols + load_cols

    def collect_cpu_usage(self):
        """Sample per-CPU times and load averages; blocks for 1 second."""
        per_cpu = psutil.cpu_times_percent(interval=1, percpu=True)
        # Load averages normalized by CPU count, expressed as a percentage.
        load = [
            format(avg / psutil.cpu_count() * 100, ".3f")
            for avg in psutil.getloadavg()
        ]
        steal = [format(t.steal, ".3f") for t in per_cpu]
        usage = [format(t.user + t.system, ".3f") for t in per_cpu]
        return steal + usage + load

    def collect(self):
        return self.collect_cpu_usage()


class NodeNetCollect(object):
    """Node network metrics: TCP netstat counters plus per-NIC io deltas.

    collect() blocks ~1 second (delta sampling window in collect_net_io).
    """

    # NICs that were up when get_net_columns() built the header;
    # collect_net_io() reports deltas for exactly these, in the same order.
    interfaces = []

    def __init__(self, pid):
        self.pid = pid
        self.columns = NodeNetCollect.get_columns()

    @classmethod
    def get_TCP_columns(cls):
        """TcpExt counters read from /proc net netstat."""
        return [
            "TCPSackRecovery",
            "TCPSACKReneging",
            "TCPAbortOnMemory",
            "TCPAbortOnTimeout",
            "TCPMemoryPressures",
            "ListenOverflows",
        ]

    @classmethod
    def get_net_columns(cls):
        """Column names for every interface that is currently up.

        Side effect: caches the up-interface list in cls.interfaces so
        collect_net_io() iterates the same NICs in the same order as the
        header.
        """
        # Query interface stats once (previously queried twice).
        stats_info = psutil.net_if_stats()
        cls.interfaces = [
            nic for nic, snicstats in stats_info.items() if snicstats.isup
        ]

        cols = []
        for nic in cls.interfaces:
            # e.g. "docker0.recv.bytes", "docker0.recv.packets",
            # "docker0.errin", "docker0.dropin"
            recv = ["recv.bytes", "recv.packets", "errin", "dropin"]
            sent = ["sent.bytes", "sent.packets", "errout", "dropout"]
            cols.extend(".".join(col) for col in itertools.product([nic], recv))
            cols.extend(".".join(col) for col in itertools.product([nic], sent))
        return cols

    @classmethod
    def get_columns(cls):
        return NodeNetCollect.get_TCP_columns() + NodeNetCollect.get_net_columns()

    def collect_TCPS(self):
        """Read the TcpExt counters listed in get_TCP_columns().

        Uses the per-pid netns view when a pid is set, the node view otherwise.
        """
        if not self.pid:
            path = "/".join(["/proc", "net", "netstat"])
        else:
            path = "/".join(["/proc", str(self.pid), "net", "netstat"])

        # First two lines are the TcpExt header/value pair.
        with open(path, mode="r") as f:
            lines = f.readlines()
            keys, values = lines[0].rstrip().split(), lines[1].rstrip().split()
            data = dict(zip(keys, values))

        return [data[target] for target in NodeNetCollect.get_TCP_columns()]

    def collect_net_io(self):
        """Per-NIC byte/packet/error/drop deltas over a 1-second window."""
        before_io_counts = psutil.net_io_counters(pernic=True)
        time.sleep(1)
        after_io_counts = psutil.net_io_counters(pernic=True)

        data = []
        for nic in NodeNetCollect.interfaces:
            if nic in before_io_counts and nic in after_io_counts:
                before_io, after_io = before_io_counts[nic], after_io_counts[nic]
                data.extend(
                    [
                        after_io.bytes_recv - before_io.bytes_recv,
                        after_io.packets_recv - before_io.packets_recv,
                        after_io.errin - before_io.errin,
                        after_io.dropin - before_io.dropin,
                    ]
                )
                data.extend(
                    [
                        after_io.bytes_sent - before_io.bytes_sent,
                        after_io.packets_sent - before_io.packets_sent,
                        after_io.errout - before_io.errout,
                        after_io.dropout - before_io.dropout,
                    ]
                )
            else:
                # BUG FIX: the original did `return data.extend([-1] * 8)`,
                # which returns None (list.extend mutates in place) and made
                # collect() crash on list + None; it also dropped the values
                # for all remaining NICs. Pad this NIC with -1s and continue
                # so the row stays aligned with the header.
                data.extend([-1] * 8)
        return data

    def collect(self):
        return self.collect_TCPS() + self.collect_net_io()


class NodeMemCollect(object):
    """Node memory usage plus the RSS of one tracked process."""

    def __init__(self, pid):
        self.pid = pid
        self.columns = NodeMemCollect.get_columns(self.pid)

    @classmethod
    def get_columns(cls, pid):
        # TODO: support pid lists?
        return ["memory_usage %", str(pid) + ".rss"]

    def collect_usage(self):
        """System memory usage as a used/total fraction.

        NOTE(review): the column is named "memory_usage %" but this emits a
        0..1 fraction, not a percentage — confirm consumers before changing.
        """
        mem = psutil.virtual_memory()
        return [format(mem.used / mem.total, ".3f")]

    def collect_rss(self):
        """Return [rss bytes] for self.pid, or [-1] when unavailable."""
        if not self.pid:
            print("statm requires pid")
            return [-1]

        try:
            # BUG FIX: pids arrive as strings in this codebase (pid="" default
            # upstream) but psutil.Process requires an int; the resulting
            # TypeError was not caught by `except psutil.Error`. Convert and
            # also catch bad pid values.
            p = psutil.Process(int(self.pid))
            meminfo = p.memory_full_info()
        except (psutil.Error, ValueError, TypeError):
            print(
                "%s : psutil.Process failed"
                % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            )
            return [-1]

        return [meminfo.rss]

    def collect(self):
        return self.collect_usage() + self.collect_rss()


class NodeIoCollect(object):
    """device iostat for the node

    Reference: https://github.com/giampaolo/psutil/blob/master/scripts/iotop.py
    """

    def __init__(self, pid=""):
        self.pid = pid

    @classmethod
    def get_columns(cls, pid):
        """Header: five node-wide disk columns, then four per-pid columns."""
        columns = []
        total = [
            "total_disk_usage",
            "total_disks_read_bps",
            "total_disks_read_iops",
            "total_disk_write_bps",
            "total_disk_write_iops",
        ]
        pid = str(pid)
        p = [
            pid + "_disk_read_bps",
            pid + "_disk_read_iops",
            pid + "_disk_write_bps",
            pid + "_disk_write_iops",
        ]
        columns.extend(total)
        columns.extend(p)
        return columns

    def collect_util(self):
        """Root filesystem usage percent."""
        usage = psutil.disk_usage("/")
        return [usage.percent]

    def collect_io(self):
        """Disk io deltas over a 1-second window: node totals + tracked pid.

        Per-pid values are -1 when no pid is set or the process can't be read.
        """
        # TODO: support pid list?
        use_pid = False
        p = None
        if self.pid:
            use_pid = True
            try:
                p = psutil.Process(self.pid)
                p._before = p.io_counters()
            except psutil.Error:
                print(
                    "%s : psutil.Process failed"
                    % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
                )
                use_pid = False
        disks_before = psutil.disk_io_counters()

        interval = 1
        time.sleep(interval)

        # BUG FIX: this was gated on `self.pid`, so when psutil.Process()
        # failed above, `p` was unbound and p.oneshot() crashed. Gate on
        # use_pid, which tracks whether the first sample actually succeeded.
        if use_pid:
            with p.oneshot():
                try:
                    p._after = p.io_counters()
                except psutil.Error:
                    print(
                        "%s : psutil.Process failed"
                        % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
                    )
                    use_pid = False
        disks_after = psutil.disk_io_counters()

        # pid disk io staff
        pid_io = [-1] * 4
        if use_pid:
            p._read_bps = p._after.read_bytes - p._before.read_bytes
            p._read_iops = p._after.read_count - p._before.read_count
            p._write_bps = p._after.write_bytes - p._before.write_bytes
            p._write_iops = p._after.write_count - p._before.write_count
            pid_io = [p._read_bps, p._read_iops, p._write_bps, p._write_iops]

        # total disk io staff
        disks_read_bps = disks_after.read_bytes - disks_before.read_bytes
        disks_read_iops = disks_after.read_count - disks_before.read_count
        disks_write_bps = disks_after.write_bytes - disks_before.write_bytes
        disks_write_iops = disks_after.write_count - disks_before.write_count
        total_io = [disks_read_bps, disks_read_iops, disks_write_bps, disks_write_iops]

        return total_io + pid_io

    def collect(self):
        return self.collect_util() + self.collect_io()


class NodeProcCollect(object):
    """Placeholder for counting processes in the uninterruptible (D) state.

    (1) Uninterruptible state: the process is asleep, but cannot be
    interrupted at the moment — it does not respond to asynchronous signals.

    TODO: get D process counts? psutil.STATUS_SLEEPING?  psutil.STATUS_DISK_SLEEP?
    """

    def __init__(self, pid=""):
        self.pid = pid

    @classmethod
    def get_columns(cls):
        # Not implemented yet: contributes no header columns.
        return []

    def collect(self):
        # Not implemented yet: contributes no data values.
        return []
